{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "O9I9rz0pTamX"
   },
   "source": [
    "# Sequential Training Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EiowR0WNTd1C"
   },
   "source": [
    "[Sequential training](https://arxiv.org/abs/1811.01088v2) involves fine-tuning a language-encoding model (e.g. BERT) on one task (the \"intermediate\" task), and then again on a second task (the \"target\" task). In many cases, the right choice of intermediate task can improve the performance on the target task compared to fine-tuning only on the target task.\n",
    "\n",
    "Between the two phases of training, we are going to carry over the language encoding model, and not the task heads.\n",
    "\n",
    "--- \n",
    "\n",
    "In this notebook, we will:\n",
    "\n",
    "* Train a RoBERTa base model on MNLI, and then further fine-tune the model on RTE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rXbD_U1_VDnw"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tC9teoazUnW8"
   },
   "source": [
    "#### Install dependencies\n",
    "\n",
    "First, we will install the libraries this notebook needs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8aU3Z9szuMU9"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!git clone https://github.com/nyu-mll/jiant.git\n",
    "%cd jiant\n",
    "!pip install -r requirements-no-torch.txt\n",
    "!pip install --no-deps -e ./"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KGJcCmRzU1Qb"
   },
   "source": [
    "#### Download data\n",
    "\n",
    "Next, we will download MNLI and RTE data. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jKCz8VksvFlN"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "%cd /content\n",
    "# Download MNLI and RTE data\n",
    "!PYTHONPATH=/content/jiant python jiant/jiant/scripts/download_data/runscript.py \\\n",
    "    download \\\n",
    "    --tasks mnli rte \\\n",
    "    --output_path=/content/tasks/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rQKSAhYzVIlv"
   },
   "source": [
    "## `jiant` Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "v88oXqmBvFuK"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"/content/jiant\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ibmMT7CXv1_P"
   },
   "outputs": [],
   "source": [
    "import jiant.proj.main.tokenize_and_cache as tokenize_and_cache\n",
    "import jiant.proj.main.export_model as export_model\n",
    "import jiant.proj.main.scripts.configurator as configurator\n",
    "import jiant.proj.main.runscript as main_runscript\n",
    "import jiant.shared.caching as caching\n",
    "import jiant.utils.python.io as py_io\n",
    "import jiant.utils.display as display\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "HPZHyLOlVp07"
   },
   "source": [
    "#### Download model\n",
    "\n",
    "Next, we will download a `roberta-base` model. This also includes the tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "K06qUGjkKWa7"
   },
   "outputs": [],
   "source": [
    "export_model.export_model(\n",
    "    hf_pretrained_model_name_or_path=\"roberta-base\",\n",
    "    output_base_path=\"./models/roberta-base\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dV-T-8r1V0wf"
   },
   "source": [
    "#### Tokenize and cache\n",
    "\n",
    "With the model and data ready, we can now tokenize and cache the input features for our tasks. This converts the input examples to tokenized features ready to be consumed by the model, and saves them to disk in chunks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "22bNWQajO4zm"
   },
   "outputs": [],
   "source": [
    "# Tokenize and cache each task\n",
    "for task_name in [\"mnli\", \"rte\"]:\n",
    "    tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
    "        task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n",
    "        hf_pretrained_model_name_or_path=\"roberta-base\",\n",
    "        output_dir=f\"./cache/{task_name}\",\n",
    "        phases=[\"train\", \"val\"],\n",
    "    ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JJ-mWSQQWJsw"
   },
   "source": [
    "We can inspect the first example in the first chunk of each task's cache."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iLk_X0KypUyr"
   },
   "outputs": [],
   "source": [
    "row = caching.ChunkedFilesDataCache(\"./cache/mnli/train\").load_chunk(0)[0][\"data_row\"]\n",
    "print(row.input_ids)\n",
    "print(row.tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2n00e6Xrp1bI"
   },
   "outputs": [],
   "source": [
    "row = caching.ChunkedFilesDataCache(\"./cache/rte/val\").load_chunk(0)[0][\"data_row\"]\n",
    "print(row.input_ids)\n",
    "print(row.tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3MBuH19IWOr0"
   },
   "source": [
    "#### Writing a run config\n",
    "\n",
    "Here we are going to write what we call a `jiant_task_container_config`. This configuration file defines many of the details of our training pipeline, such as which tasks we train on, which tasks we evaluate on, and the batch size for each task. The new version of `jiant` leans heavily toward explicitly specifying everything, for the purpose of inspectability and leaving minimal surprises for the user, even at the cost of being more verbose.\n",
    "\n",
    "Since we are training in two phases, we will need to write two run configs: one for MNLI, and one for RTE. (This might seem tedious, but note that these can be easily reused across different combinations of intermediate and target tasks.)\n",
    "\n",
    "We use a helper \"Configurator\" to write out a `jiant_task_container_config`, since most of our setup is pretty standard. \n",
    "\n",
    "We start with the MNLI config:\n",
    "\n",
    "**Depending on which GPU your Colab session is assigned, you may need to lower the train batch size.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "pQYtl7xTKsiP"
   },
   "outputs": [],
   "source": [
    "jiant_run_config = configurator.SimpleAPIMultiTaskConfigurator(\n",
    "    task_config_base_path=\"./tasks/configs\",\n",
    "    task_cache_base_path=\"./cache\",\n",
    "    train_task_name_list=[\"mnli\"],\n",
    "    val_task_name_list=[\"mnli\"],\n",
    "    train_batch_size=8,\n",
    "    eval_batch_size=16,\n",
    "    epochs=0.1,\n",
    "    num_gpus=1,\n",
    ").create_config()\n",
    "os.makedirs(\"./run_configs/\", exist_ok=True)\n",
    "py_io.write_json(jiant_run_config, \"./run_configs/mnli_run_config.json\")\n",
    "display.show_json(jiant_run_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-UF501yoXHBi"
   },
   "source": [
    "To briefly go over the major components of the `jiant_task_container_config`:\n",
    "\n",
    "* `task_config_path_dict`: The paths to the task config files generated by the download step above.\n",
    "* `task_cache_config_dict`: The paths to the task features caches we generated above.\n",
    "* `sampler_config`: Determines how to sample from different tasks during training.\n",
    "* `global_train_config`: The number of total steps and warmup steps during training.\n",
    "* `task_specific_configs_dict`: Task-specific arguments for each task, such as training batch size and gradient accumulation steps.\n",
    "* `taskmodels_config`: Task-model specific arguments for each task-model, including what tasks use which model.\n",
    "* `metric_aggregator_config`: Determines how to weight/aggregate the metrics across multiple tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4T-9kuT75V5J"
   },
   "source": [
    "Next, we will write the equivalent for RTE."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "lQrT9pm24DZY"
   },
   "outputs": [],
   "source": [
    "jiant_run_config = configurator.SimpleAPIMultiTaskConfigurator(\n",
    "    task_config_base_path=\"./tasks/configs\",\n",
    "    task_cache_base_path=\"./cache\",\n",
    "    train_task_name_list=[\"rte\"],\n",
    "    val_task_name_list=[\"rte\"],\n",
    "    train_batch_size=8,\n",
    "    eval_batch_size=16,\n",
    "    epochs=0.5,\n",
    "    num_gpus=1,\n",
    ").create_config()\n",
    "os.makedirs(\"./run_configs/\", exist_ok=True)\n",
    "py_io.write_json(jiant_run_config, \"./run_configs/rte_run_config.json\")\n",
    "display.show_json(jiant_run_config)"
   ]
  },
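  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a rough sanity check on the `epochs` and `train_batch_size` settings in the two configs, the cell below estimates how many training steps each phase will run. The training-set sizes are the commonly cited GLUE figures and are an assumption here, and the estimate ignores gradient accumulation; the authoritative step count is the one reported under `global_train_config` in the printed configs, and the configurator's exact rounding may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Assumed GLUE training-set sizes (not read from the cache on disk).\n",
    "train_sizes = {\"mnli\": 392_702, \"rte\": 2_490}\n",
    "# Settings copied from the two run configs above.\n",
    "epochs = {\"mnli\": 0.1, \"rte\": 0.5}\n",
    "train_batch_size = 8\n",
    "\n",
    "for task, size in train_sizes.items():\n",
    "    # Rough estimate only: fraction of an epoch times batches per epoch.\n",
    "    approx_steps = math.ceil(epochs[task] * size / train_batch_size)\n",
    "    print(f\"{task}: about {approx_steps} training steps\")"
   ]
  },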
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BBKkvXzdYPqZ"
   },
   "source": [
    "#### Start training\n",
    "\n",
    "Finally, we can start our training run. \n",
    "\n",
    "Before starting training, the script also prints out the list of parameters in our model. In the first phase, we are simply training on MNLI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JdwWPgjQWx6I"
   },
   "outputs": [],
   "source": [
    "run_args = main_runscript.RunConfiguration(\n",
    "    jiant_task_container_config_path=\"./run_configs/mnli_run_config.json\",\n",
    "    output_dir=\"./runs/mnli\",\n",
    "    hf_pretrained_model_name_or_path=\"roberta-base\",\n",
    "    model_path=\"./models/roberta-base/model/model.p\",\n",
    "    model_config_path=\"./models/roberta-base/model/config.json\",\n",
    "    learning_rate=1e-5,\n",
    "    eval_every_steps=500,\n",
    "    do_train=True,\n",
    "    do_val=True,\n",
    "    do_save=True,\n",
    "    force_overwrite=True,\n",
    ")\n",
    "main_runscript.run_loop(run_args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Dk2KQwNt5fTG"
   },
   "source": [
    "The above run saves the best model weights to `./runs/mnli/best_model.p`. Now, we will pick up from those saved model weights and start training on RTE. In addition to changing the `model_path`, we also set `model_load_mode=\"partial\"`. This tells `jiant` that we will not be loading and reusing the task heads from the previous run."
   ]
  },
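  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the idea of a partial load concrete, the toy cell below uses plain dictionaries in place of real checkpoints. The key names are hypothetical and do not reflect `jiant`'s actual parameter names or loading code; the point is only that the encoder weights carry over from the MNLI run while the RTE task head starts fresh."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Illustrative sketch only: plain dicts stand in for model state dicts.\n",
    "# The key prefixes (\"encoder.\", \"taskmodels_dict.\") are hypothetical.\n",
    "saved_weights = {\n",
    "    \"encoder.layer.0.weight\": \"tuned on MNLI\",\n",
    "    \"encoder.layer.1.weight\": \"tuned on MNLI\",\n",
    "    \"taskmodels_dict.mnli.head.weight\": \"MNLI head\",\n",
    "}\n",
    "fresh_model = {\n",
    "    \"encoder.layer.0.weight\": \"pretrained\",\n",
    "    \"encoder.layer.1.weight\": \"pretrained\",\n",
    "    \"taskmodels_dict.rte.head.weight\": \"randomly initialized\",\n",
    "}\n",
    "\n",
    "# Keep only the shared encoder weights from the checkpoint; drop the MNLI head.\n",
    "encoder_only = {k: v for k, v in saved_weights.items() if k.startswith(\"encoder.\")}\n",
    "fresh_model.update(encoder_only)\n",
    "\n",
    "for name, value in fresh_model.items():\n",
    "    print(f\"{name}: {value}\")"
   ]
  },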
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hmtRJc-84Msh"
   },
   "outputs": [],
   "source": [
    "run_args = main_runscript.RunConfiguration(\n",
    "    jiant_task_container_config_path=\"./run_configs/rte_run_config.json\",\n",
    "    output_dir=\"./runs/mnli___rte\",\n",
    "    hf_pretrained_model_name_or_path=\"roberta-base\",\n",
    "    model_path=\"./runs/mnli/best_model.p\",  # Loading the best model\n",
    "    model_load_mode=\"partial\",\n",
    "    model_config_path=\"./models/roberta-base/model/config.json\",\n",
    "    learning_rate=1e-5,\n",
    "    eval_every_steps=500,\n",
    "    do_train=True,\n",
    "    do_val=True,\n",
    "    force_overwrite=True,\n",
    ")\n",
    "main_runscript.run_loop(run_args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4SXcuHFIYp6Y"
   },
   "source": [
    "Finally, we should see the validation scores for RTE. You can compare these to training only on RTE; the sequentially trained model will often score noticeably higher."
   ]
  }
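  ,
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "For an RTE-only baseline, a run could look like the sketch below. It reuses the RTE run config and the same settings as the first phase, but starts from the exported `roberta-base` weights rather than the MNLI checkpoint. The output directory name `./runs/rte_only` is an arbitrary choice for this sketch, and the cell assumes the earlier cells in this notebook have already been run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Baseline sketch: fine-tune on RTE only, starting from pretrained roberta-base.\n",
    "run_args = main_runscript.RunConfiguration(\n",
    "    jiant_task_container_config_path=\"./run_configs/rte_run_config.json\",\n",
    "    output_dir=\"./runs/rte_only\",  # arbitrary name chosen for this sketch\n",
    "    hf_pretrained_model_name_or_path=\"roberta-base\",\n",
    "    model_path=\"./models/roberta-base/model/model.p\",\n",
    "    model_config_path=\"./models/roberta-base/model/config.json\",\n",
    "    learning_rate=1e-5,\n",
    "    eval_every_steps=500,\n",
    "    do_train=True,\n",
    "    do_val=True,\n",
    "    force_overwrite=True,\n",
    ")\n",
    "main_runscript.run_loop(run_args)"
   ]
  }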
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "jiant STILTs Example",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}